/* ************************************************************************
 * Copyright 2016 Advanced Micro Devices, Inc.
 *
 * ************************************************************************ */
#include <hip/hip_runtime.h>


#include "rocblas.h"
#include "definitions.h"


// Overrides fmaf() for the rest of this translation unit with a plain
// multiply followed by an add. Per the original author's note, the fused
// call was about 50% slower in this kernel.
// The expansion is fully parenthesized so it keeps its value when it appears
// inside a larger expression (e.g. `2 * fmaf(a, b, c)`).
#define fmaf(a, b, c) ((a) * (b) + (c))

/*
 * M6x6 -- one update step of a 6x6 accumulator from two 6-element vectors.
 *
 * Expands in place and relies on these names being in scope at the use site:
 *   lA, lB     - indexable buffers the operands are read from
 *   offA, offB - current read offsets into lA / lB (modified by the macro)
 *   rA, rB     - [1][6] staging arrays
 *   rC         - [6][6] accumulator
 *
 * Steps:
 *   1. rA[0][i] = lA[offA + 16*i] and rB[0][j] = lB[offB + 16*j], i, j in 0..5.
 *   2. offA and offB each advance by 97.
 *   3. rC[i][j] = fmaf(rA[0][i], rB[0][j], rC[i][j]) for all 36 (i, j) pairs.
 *
 * The body is wrapped in braces so the expansion is a single compound
 * statement: it stays intact under an unbraced if/for, and it still needs no
 * trailing semicolon, so back-to-back uses (`M6x6 M6x6 ...`) remain valid.
 */
#define  M6x6 \
        { \
            /* stage six operands from each buffer */ \
            rA[0][0] = lA[offA + 0]; \
            rA[0][1] = lA[offA + 16]; \
            rA[0][2] = lA[offA + 32]; \
            rA[0][3] = lA[offA + 48]; \
            rA[0][4] = lA[offA + 64]; \
            rA[0][5] = lA[offA + 80]; \
            rB[0][0] = lB[offB + 0]; \
            rB[0][1] = lB[offB + 16]; \
            rB[0][2] = lB[offB + 32]; \
            rB[0][3] = lB[offB + 48]; \
            rB[0][4] = lB[offB + 64]; \
            rB[0][5] = lB[offB + 80]; \
            /* move both read positions forward by one stride of 97 */ \
            offA += 97; \
            offB += 97; \
            /* accumulate the outer product, one column of rC at a time */ \
            rC[0][0]=fmaf(rA[0][0],rB[0][0],rC[0][0]); \
            rC[1][0]=fmaf(rA[0][1],rB[0][0],rC[1][0]); \
            rC[2][0]=fmaf(rA[0][2],rB[0][0],rC[2][0]); \
            rC[3][0]=fmaf(rA[0][3],rB[0][0],rC[3][0]); \
            rC[4][0]=fmaf(rA[0][4],rB[0][0],rC[4][0]); \
            rC[5][0]=fmaf(rA[0][5],rB[0][0],rC[5][0]); \
            rC[0][1]=fmaf(rA[0][0],rB[0][1],rC[0][1]); \
            rC[1][1]=fmaf(rA[0][1],rB[0][1],rC[1][1]); \
            rC[2][1]=fmaf(rA[0][2],rB[0][1],rC[2][1]); \
            rC[3][1]=fmaf(rA[0][3],rB[0][1],rC[3][1]); \
            rC[4][1]=fmaf(rA[0][4],rB[0][1],rC[4][1]); \
            rC[5][1]=fmaf(rA[0][5],rB[0][1],rC[5][1]); \
            rC[0][2]=fmaf(rA[0][0],rB[0][2],rC[0][2]); \
            rC[1][2]=fmaf(rA[0][1],rB[0][2],rC[1][2]); \
            rC[2][2]=fmaf(rA[0][2],rB[0][2],rC[2][2]); \
            rC[3][2]=fmaf(rA[0][3],rB[0][2],rC[3][2]); \
            rC[4][2]=fmaf(rA[0][4],rB[0][2],rC[4][2]); \
            rC[5][2]=fmaf(rA[0][5],rB[0][2],rC[5][2]); \
            rC[0][3]=fmaf(rA[0][0],rB[0][3],rC[0][3]); \
            rC[1][3]=fmaf(rA[0][1],rB[0][3],rC[1][3]); \
            rC[2][3]=fmaf(rA[0][2],rB[0][3],rC[2][3]); \
            rC[3][3]=fmaf(rA[0][3],rB[0][3],rC[3][3]); \
            rC[4][3]=fmaf(rA[0][4],rB[0][3],rC[4][3]); \
            rC[5][3]=fmaf(rA[0][5],rB[0][3],rC[5][3]); \
            rC[0][4]=fmaf(rA[0][0],rB[0][4],rC[0][4]); \
            rC[1][4]=fmaf(rA[0][1],rB[0][4],rC[1][4]); \
            rC[2][4]=fmaf(rA[0][2],rB[0][4],rC[2][4]); \
            rC[3][4]=fmaf(rA[0][3],rB[0][4],rC[3][4]); \
            rC[4][4]=fmaf(rA[0][4],rB[0][4],rC[4][4]); \
            rC[5][4]=fmaf(rA[0][5],rB[0][4],rC[5][4]); \
            rC[0][5]=fmaf(rA[0][0],rB[0][5],rC[0][5]); \
            rC[1][5]=fmaf(rA[0][1],rB[0][5],rC[1][5]); \
            rC[2][5]=fmaf(rA[0][2],rB[0][5],rC[2][5]); \
            rC[3][5]=fmaf(rA[0][3],rB[0][5],rC[3][5]); \
            rC[4][5]=fmaf(rA[0][4],rB[0][5],rC[4][5]); \
            rC[5][5]=fmaf(rA[0][5],rB[0][5],rC[5][5]); \
        }

             //__threadfence_block(); \ does not compile


template<typename T, rocblas_int NB>
__global__ void trmm_left_lower_nontrans_MX096_NX096_KX16(hipLaunchParm lp,
    rocblas_fill uplo,
    rocblas_operation transA,
    rocblas_diagonal diag,
    rocblas_int M, rocblas_int N,
    const T *alpha,
    const T *A, rocblas_int lda,
    const T *B, rocblas_int ldb,
    T *C, rocblas_int ldc)
{

    T rC[6][6]  = { {(T)0} };
    T rA[1][6];
    T rB[1][6];

    __shared__ T lA[1552];
    __shared__ T lB[1552];

    T *plA, *plB;

    uint gidx = hipBlockIdx_x;
    uint gidy = hipBlockIdx_y;
    uint idx = hipThreadIdx_x; //get_local_id(0);
    uint idy = hipThreadIdx_y; //get_local_id(1);

    A +=  gidx*96+ idx + idy*lda;
    B +=  gidy*96*ldb+ idx + idy*ldb;

    uint block_k = K >> 4;
    do {
        plA = lA + idy*97+idx;
        plB = lB + idx*97+idy;

        plB[0] = B[0];
        plB[16] = B[16*ldb];
        plB[32] = B[32*ldb];
        plB[48] = B[48*ldb];
        plB[64] = B[64*ldb];
        plB[80] = B[80*ldb];

	    plA[0] = A[0+0*lda];
        plA[16] = A[16+0*lda];
        plA[32] = A[32+0*lda];
        plA[48] = A[48+0*lda];
        plA[64] = A[64+0*lda];
        plA[80] = A[80+0*lda];


        __syncthreads();

        uint offA = idx;
        uint offB = idy;

        M6x6
	      M6x6
	      M6x6
	      M6x6
	      M6x6
	      M6x6
	      M6x6
	      M6x6
	      M6x6
	      M6x6
	      M6x6
	      M6x6
	      M6x6
	      M6x6
	      M6x6
	      M6x6

        A += lda<<4;
        B += 16;
	} while (--block_k > 0);

    C+= gidx*96+idx;
    C+= gidy*96*ldc;
    C+= idy*ldc;

    C[0*ldc] = alpha*rC[0][0] ;
    C[16*ldc] = alpha*rC[0][1] ;
    C[32*ldc] = alpha*rC[0][2] ;
    C[48*ldc] = alpha*rC[0][3] ;
    C[64*ldc] = alpha*rC[0][4] ;
    C[80*ldc] = alpha*rC[0][5] ;
    C+=16;
    C[0*ldc] = alpha*rC[1][0] ;
    C[16*ldc] = alpha*rC[1][1] ;
    C[32*ldc] = alpha*rC[1][2] ;
    C[48*ldc] = alpha*rC[1][3] ;
    C[64*ldc] = alpha*rC[1][4] ;
    C[80*ldc] = alpha*rC[1][5] ;
    C+=16;
    C[0*ldc] = alpha*rC[2][0] ;
    C[16*ldc] = alpha*rC[2][1] ;
    C[32*ldc] = alpha*rC[2][2] ;
    C[48*ldc] = alpha*rC[2][3] ;
    C[64*ldc] = alpha*rC[2][4] ;
    C[80*ldc] = alpha*rC[2][5] ;
    C+=16;
    C[0*ldc] = alpha*rC[3][0] ;
    C[16*ldc] = alpha*rC[3][1] ;
    C[32*ldc] = alpha*rC[3][2] ;
    C[48*ldc] = alpha*rC[3][3] ;
    C[64*ldc] = alpha*rC[3][4] ;
    C[80*ldc] = alpha*rC[3][5] ;
    C+=16;
    C[0*ldc] = alpha*rC[4][0] ;
    C[16*ldc] = alpha*rC[4][1] ;
    C[32*ldc] = alpha*rC[4][2] ;
    C[48*ldc] = alpha*rC[4][3] ;
    C[64*ldc] = alpha*rC[4][4] ;
    C[80*ldc] = alpha*rC[4][5] ;
    C+=16;
    C[0*ldc] = alpha*rC[5][0] ;
    C[16*ldc] = alpha*rC[5][1];
    C[32*ldc] = alpha*rC[5][2];
    C[48*ldc] = alpha*rC[5][3];
    C[64*ldc] = alpha*rC[5][4];
    C[80*ldc] = alpha*rC[5][5];

}

/*! \brief BLAS Level 3 API

    \details

    trmm performs the matrix-matrix operation

    C := alpha*op( A )*B,   or   C := alpha*B*op( A )

    where  alpha  is a scalar,  B  is an m by n matrix,  A  is a unit, or
    non-unit,  upper or lower triangular matrix  and  op( A )  is one  of

        op( A ) = A   or   op( A ) = A^T   or   op( A ) = A^H.

    @param[in]
    handle    rocblas_handle.
              handle to the rocblas library context queue.

    @param[in]
    side    rocblas_side.
            rocblas_side_left:       C := alpha*op( A )*B.
            rocblas_side_right:      C := alpha*B*op( A ).

    @param[in]
    uplo    rocblas_fill.
            rocblas_fill_upper:  A is an upper triangular matrix.
            rocblas_fill_lower:  A is a  lower triangular matrix.

    @param[in]
    transA  rocblas_operation.
            rocblas_operation_none:    op(A) = A.
            rocblas_operation_transpose:      op(A) = A^T.
            rocblas_operation_conjugate_transpose:  op(A) = A^H.

    @param[in]
    diag    rocblas_diagonal.
            rocblas_diagonal_unit:      A is assumed to be unit triangular.
            rocblas_diagonal_non_unit:  A is not assumed to be unit triangular.

    @param[in]
    m       rocblas_int.
            m specifies the number of rows of B. m >= 0.

    @param[in]
    n       rocblas_int.
            n specifies the number of columns of B. n >= 0.

    @param[in]
    alpha
            alpha specifies the scalar alpha. When alpha is
            zero then A is not referenced and B need not be set before
            entry.

    @param[in]
    A       pointer storing matrix A on the GPU.
            of dimension ( lda, k ), where k is m
            when  rocblas_side_left  and
            is  n  when  rocblas_side_right
            only the upper/lower triangular part is accessed.

    @param[in]
    lda     rocblas_int.
            lda specifies the first dimension of A.
            if side = rocblas_side_left,  lda >= max( 1, m ),
            if side = rocblas_side_right, lda >= max( 1, n ).

    @param[in]
    B       pointer storing matrix B on the GPU.

    @param[in]
    ldb    rocblas_int.
           ldb specifies the first dimension of B. ldb >= max( 1, m ).

    @param[out]
    C       pointer storing matrix C on the GPU.

    @param[in]
    ldc    rocblas_int.
           ldc specifies the first dimension of C. ldc >= max( 1, m ).

    ********************************************************************/

// Thread-block edge length. Each thread produces a 6 x 6 patch, so one block
// covers an (NB_X*6) x (NB_X*6) = 96 x 96 tile of C.
#define NB_X 16

/*
 * Validates the arguments and launches the tile kernel
 * trmm_left_lower_nontrans_MX096_NX096_KX16.
 *
 * Returns
 *   rocblas_status_invalid_handle   handle is null
 *   rocblas_status_invalid_size     M or N negative, or a leading dimension
 *                                   is smaller than required
 *   rocblas_status_invalid_pointer  alpha, A, B or C is null
 *   rocblas_status_success          M == 0 or N == 0 (nothing to do), or the
 *                                   kernel was enqueued
 *   rocblas_status_not_implemented  any case other than side = left,
 *                                   uplo = lower, transA = none, or M / N
 *                                   not a multiple of 96
 *
 * Argument checks are evaluated in the order listed in the body, so the
 * first failing check determines the status.
 *
 * alpha is forwarded to the kernel as a pointer because the kernel declares
 * `const T *alpha`; it is not dereferenced on the host.
 * NOTE(review): this requires alpha to be readable from device code --
 * confirm against the callers' pointer convention.
 */
template<typename T>
rocblas_status
rocblas_trmm_template(rocblas_handle handle,
    rocblas_side side,
    rocblas_fill uplo,
    rocblas_operation transA,
    rocblas_diagonal diag,
    rocblas_int M, rocblas_int N,
    const T *alpha,
    const T *A, rocblas_int lda,
    const T *B, rocblas_int ldb,
    T *C, rocblas_int ldc)
{
    // A is M x M when applied from the left, N x N from the right.
    rocblas_int A_row = ( side == rocblas_side_left ?  M : N);

    if(handle == nullptr)
        return rocblas_status_invalid_handle;
    else if ( M < 0 )
        return rocblas_status_invalid_size;
    else if ( N < 0 )
        return rocblas_status_invalid_size;
    else if ( alpha == nullptr )
        return rocblas_status_invalid_pointer;
    else if ( A == nullptr )
        return rocblas_status_invalid_pointer;
    else if ( lda < A_row )
        return rocblas_status_invalid_size;
    else if ( B == nullptr )
        return rocblas_status_invalid_pointer;
    else if ( ldb < M )
        return rocblas_status_invalid_size;
    else if ( C == nullptr )
        return rocblas_status_invalid_pointer;
    else if ( ldc < M )
        return rocblas_status_invalid_size;

    /*
     * Quick return if possible.
     */

    if ( M == 0 || N == 0 )
        return rocblas_status_success;

    // The only kernel in this file handles A on the left, lower triangular,
    // not transposed. Everything else is reported rather than silently
    // computed with the wrong kernel.
    if(side != rocblas_side_left || uplo != rocblas_fill_lower || transA != rocblas_operation_none){
        return rocblas_status_not_implemented;
    }

    // The kernel reads and writes whole 96 x 96 tiles with no edge handling,
    // so partial tiles would access memory outside the matrices.
    const rocblas_int tile = NB_X*6;
    if(M % tile != 0 || N % tile != 0){
        return rocblas_status_not_implemented;
    }

    rocblas_int blocks_x = M / tile;
    rocblas_int blocks_y = N / tile;

    dim3 grid(blocks_x, blocks_y, 1);
    dim3 threads(NB_X, NB_X, 1);

    hipStream_t rocblas_stream;
    RETURN_IF_ROCBLAS_ERROR(rocblas_get_stream(handle, &rocblas_stream));

    // The argument list mirrors the kernel's parameters after hipLaunchParm,
    // which the launch macro supplies itself.
    hipLaunchKernel(HIP_KERNEL_NAME(trmm_left_lower_nontrans_MX096_NX096_KX16<T, NB_X>), dim3(grid), dim3(threads), 0, rocblas_stream,
        uplo, transA, diag, M, N, alpha, A, lda, B, ldb, C, ldc);

    return rocblas_status_success;

}


/* ============================================================================================ */

    /*
     * ===========================================================================
     *    template interface
     *    template specialization
     * ===========================================================================
     */


/*
 * Single-precision specialization: forwards to rocblas_trmm_template<float>.
 *
 * rocblas_trmm_template expects (side, uplo, transA, diag) after the handle,
 * but this signature only carries transA. The remaining three are fixed to
 * left / lower / non-unit, the case the kernel in this file is written for.
 *
 * NOTE(review): transB and K have no counterpart in rocblas_trmm_template
 * and are ignored here. This signature looks gemm-shaped; confirm it against
 * the primary template's declaration.
 */
template<>
rocblas_status
rocblas_trmm<float>(rocblas_handle handle,
    rocblas_operation transA,
    rocblas_operation transB,
    rocblas_int M, rocblas_int N, rocblas_int K,
    const float *alpha,
    const float *A, rocblas_int lda,
    const float *B, rocblas_int ldb,
    float *C, rocblas_int ldc)
{
    (void)transB;   // not part of the trmm template's parameter list
    (void)K;        // likewise

    return rocblas_trmm_template<float>(handle,
        rocblas_side_left, rocblas_fill_lower, transA, rocblas_diagonal_non_unit,
        M, N, alpha, A, lda, B, ldb, C, ldc);
}

/*
 * Double-precision specialization: forwards to rocblas_trmm_template<double>.
 *
 * rocblas_trmm_template expects (side, uplo, transA, diag) after the handle,
 * but this signature only carries transA. The remaining three are fixed to
 * left / lower / non-unit, the case the kernel in this file is written for.
 *
 * NOTE(review): transB and K have no counterpart in rocblas_trmm_template
 * and are ignored here. This signature looks gemm-shaped; confirm it against
 * the primary template's declaration.
 */
template<>
rocblas_status
rocblas_trmm<double>(rocblas_handle handle,
    rocblas_operation transA,
    rocblas_operation transB,
    rocblas_int M, rocblas_int N, rocblas_int K,
    const double *alpha,
    const double *A, rocblas_int lda,
    const double *B, rocblas_int ldb,
    double *C, rocblas_int ldc)
{
    (void)transB;   // not part of the trmm template's parameter list
    (void)K;        // likewise

    return rocblas_trmm_template<double>(handle,
        rocblas_side_left, rocblas_fill_lower, transA, rocblas_diagonal_non_unit,
        M, N, alpha, A, lda, B, ldb, C, ldc);
}

/* ============================================================================================ */

    /*
     * ===========================================================================
     *    C wrapper
     * ===========================================================================
     */




/* ============================================================================================ */
